#include "shmhdl.hpp"
#define CATCH_CONFIG_MAIN
#include <catch2/catch.hpp>
#include <chrono>
#include <random>
#include <thread>

using namespace std::chrono_literals;
namespace libshm = shm_kernel::shared_memory;

TEST_CASE("create shm_handle", "[create]")
{
  // A handle created with an explicit size (the "server" constructor) owns a
  // descriptor, is not mapped yet, is the only user of the object, and
  // reports a size no smaller than what was requested.
  constexpr auto     requested = 4096;
  libshm::shm_handle shm("test_shm", requested);

  REQUIRE(shm.fd() > 0);
  REQUIRE(shm.addr() == nullptr);
  REQUIRE(shm.ref_count() == 1);
  REQUIRE(shm.nbytes() >= requested);
}

TEST_CASE("create shm_handle and destroy it", "[destructor]")
{
  // Creates a server-side handle of `requested` bytes, checks its initial
  // state, and destroys it when the lambda returns. This lets the same name
  // be created again below with a different size.
  const auto create_check_destroy = [](int requested) {
    libshm::shm_handle shm("test_shm", requested);
    REQUIRE(shm.fd() > 0);
    REQUIRE(shm.addr() == nullptr);
    REQUIRE(shm.ref_count() == 1);
    REQUIRE(shm.nbytes() > requested);
  };

  create_check_destroy(4096);
  // NOTE(review): the pause presumably lets teardown settle before the name
  // is reused — confirm whether it is actually required.
  std::this_thread::sleep_for(100ms);

  create_check_destroy(1024);
  std::this_thread::sleep_for(100ms);
}

TEST_CASE("two process use the same shm obj (client/server)", "[constructor]")
{
  // Checks the state of a handle that was just opened (server or client):
  // valid fd, not yet mapped, the expected total number of users, and the
  // size derived from the server's 4096-byte request.
  const auto require_fresh = [](auto& shm, int expected_refs) {
    REQUIRE(shm.fd() > 0);
    REQUIRE(shm.addr() == nullptr);
    REQUIRE(shm.ref_count() == expected_refs);
    REQUIRE(shm.nbytes() > 4096);
  };

  libshm::shm_handle server("test_shm", 4096);
  require_fresh(server, 1);

  // Each handle opened by name only (a client) raises the shared ref_count.
  libshm::shm_handle client1("test_shm");
  require_fresh(client1, 2);

  REQUIRE(server.nbytes() == client1.nbytes());
  REQUIRE(server.ref_count() == client1.ref_count());

  libshm::shm_handle client2("test_shm");
  require_fresh(client2, 3);

  libshm::shm_handle client3("test_shm");
  require_fresh(client3, 4);

  // Checks that a short-lived client sees the same object as the server and
  // that the shared user count is now `expected_refs`.
  const auto require_mirrors_server = [&server](auto& clt, int expected_refs) {
    REQUIRE(server.ref_count() == expected_refs);
    REQUIRE(clt.nbytes() == server.nbytes());
    REQUIRE(clt.ref_count() == server.ref_count());
  };

  SECTION("check when clients die, ref_count will decrease")
  {
    {
      libshm::shm_handle transient("test_shm");
      require_mirrors_server(transient, 5);
    }
    REQUIRE(server.ref_count() == 4);

    {
      libshm::shm_handle transient1("test_shm");
      require_mirrors_server(transient1, 5);

      libshm::shm_handle transient2("test_shm");
      require_mirrors_server(transient2, 6);

      libshm::shm_handle transient3("test_shm");
      require_mirrors_server(transient3, 7);

      libshm::shm_handle transient4("test_shm");
      require_mirrors_server(transient4, 8);
    }
    // All four transient clients are gone; only server + client1..3 remain.
    REQUIRE(server.ref_count() == 4);
  }
}

TEST_CASE("when two server use the same shm name (server/server), throw!",
          "[constructor]")
{
  libshm::shm_handle owner("test_shm", 4096);
  REQUIRE(owner.fd() > 0);
  REQUIRE(owner.addr() == nullptr);
  REQUIRE(owner.ref_count() == 1);
  REQUIRE(owner.nbytes() > 4096);

  // A second sized ("server") construction under an existing name must be
  // rejected, whatever size it requests.
  for (int size : {4096, 2048}) {
    REQUIRE_THROWS(libshm::shm_handle("test_shm", size));
  }
}

TEST_CASE("map shared memory object into current process memory space.",
          "[map]")
{
  libshm::shm_handle shm("test_shm", 4096);
  REQUIRE(shm.fd() > 0);
  REQUIRE(shm.addr() == nullptr);
  REQUIRE(shm.ref_count() == 1);
  REQUIRE(shm.nbytes() > 4096);

  // The first map() establishes the address the handle reports via addr().
  void* const first = shm.map();
  REQUIRE(first != nullptr);
  REQUIRE(shm.addr() == first);

  // Mapping again while already mapped returns the same address.
  void* const again = shm.map();
  REQUIRE(again != nullptr);
  REQUIRE(again == first);

  // The std::error_code overload returns the same address and reports success.
  std::error_code err;
  void* const     via_ec = shm.map(err);
  REQUIRE(via_ec != nullptr);
  REQUIRE(via_ec == first);
  REQUIRE(err.value() == 0);
}

TEST_CASE("unmap shared memory object", "[unmap]")
{
  libshm::shm_handle shm("test_shm", 4096);
  REQUIRE(shm.fd() > 0);
  REQUIRE(shm.addr() == nullptr);
  REQUIRE(shm.ref_count() == 1);
  REQUIRE(shm.nbytes() > 4096);

  // unmap() must leave the handle with no recorded address.
  const auto unmap_and_check = [&shm] {
    shm.unmap();
    REQUIRE(shm.addr() == nullptr);
  };

  void* const mapped = shm.map();
  REQUIRE(mapped != nullptr);
  REQUIRE(shm.addr() == mapped);
  unmap_and_check();

  // After an unmap the handle can be mapped again (error_code overload).
  std::error_code err;
  void* const     remapped = shm.map(err);
  REQUIRE(remapped != nullptr);
  REQUIRE(remapped == shm.addr());
  REQUIRE(err.value() == 0);
  unmap_and_check();
}

TEST_CASE("remap a shared memory obejct", "[map]")
{
  libshm::shm_handle shm("test_shm", 4096);
  REQUIRE(shm.fd() > 0);
  REQUIRE(shm.addr() == nullptr);
  REQUIRE(shm.ref_count() == 1);
  REQUIRE(shm.nbytes() > 4096);

  // Map once and remember the address, then unmap. The saved pointer is
  // intended as the placement hint for the remap checks below.
  void* const previous = shm.map();
  REQUIRE(previous != nullptr);
  REQUIRE(shm.addr() == previous);
  shm.unmap();
  REQUIRE(shm.addr() == nullptr);
  REQUIRE(previous != nullptr);

  // NOTE(review): the remap-at-address checks are disabled. They call a
  // map(void*) / map(void*, std::error_code&) overload that is presumably not
  // implemented yet — confirm before re-enabling.
  //
  // void* again = shm.map(previous);
  // REQUIRE(again != nullptr);
  // REQUIRE(again == shm.addr());
  // REQUIRE(again == previous);
  // shm.unmap();
  // REQUIRE(shm.addr() == nullptr);
  //
  // std::error_code err;
  // void* again_ec = shm.map(previous, err);
  // REQUIRE(again_ec != nullptr);
  // REQUIRE(again_ec == shm.addr());
  // REQUIRE(again_ec == previous);
  // REQUIRE(err.value() == 0);
  // shm.unmap();
  // REQUIRE(shm.addr() == nullptr);
}

TEST_CASE("two process interact the same buffer", "[map]")
{
  // The server fills its mapping with random doubles. A second handle
  // attached by name must then read exactly the same values through its own
  // mapping of the object.
  libshm::shm_handle server("test_shm", 4096);
  REQUIRE(server.addr() == nullptr);
  REQUIRE(server.nbytes() > 4096);
  REQUIRE(server.ref_count() == 1);

  double* svr_ptr = static_cast<double*>(server.map());
  REQUIRE(svr_ptr != nullptr);
  REQUIRE(svr_ptr == server.addr());

  // Number of doubles that fit in the 4096 bytes requested above.
  const size_t len = 4096 / sizeof(double);

  // Store random numbers into the shared buffer. The sum MUST be accumulated
  // in floating point. Every value lies in [0, 1), so a size_t accumulator
  // truncates each partial sum back to 0, which made the final comparison
  // vacuous.
  std::mt19937                           engine;
  std::uniform_real_distribution<double> dist(0.0, 1.0);
  double                                 svr_sum = 0.0;
  for (size_t i = 0; i < len; ++i) {
    svr_ptr[i] = dist(engine);
    svr_sum += svr_ptr[i];
  }
  INFO("server side random number sum: " << svr_sum);
  // Guards against the buffer silently staying zero-filled.
  REQUIRE(svr_sum > 0.0);

  // setup client
  libshm::shm_handle client("test_shm");
  REQUIRE(client.addr() == nullptr);
  REQUIRE(client.nbytes() > 4096);
  REQUIRE(client.ref_count() == 2);
  REQUIRE(server.ref_count() == client.ref_count());

  double* clt_ptr = static_cast<double*>(client.map());
  REQUIRE(clt_ptr != nullptr);
  REQUIRE(clt_ptr == client.addr());
  // Each handle maps the object separately, so the two addresses are
  // expected to differ even though they view the same data.
  CHECK(clt_ptr != svr_ptr);

  // Both mappings view the same bytes, so every value must match exactly,
  // not merely approximately.
  double clt_sum    = 0.0;
  size_t mismatches = 0;
  for (size_t i = 0; i < len; ++i) {
    clt_sum += clt_ptr[i];
    if (clt_ptr[i] != svr_ptr[i]) {
      ++mismatches;
    }
  }
  INFO("client side random number sum: " << clt_sum);
  REQUIRE(mismatches == 0);
  REQUIRE(svr_sum == Approx(clt_sum));
}

TEST_CASE("when unlink is called, new handle can't attach to it.", "[unlink]")
{
  libshm::shm_handle owner("test_shm", 4096);
  REQUIRE(owner.addr() == nullptr);
  REQUIRE(owner.nbytes() > 4096);
  REQUIRE(owner.ref_count() == 1);

  // Before unlink: attaching by name succeeds. The temporary client handle is
  // destroyed at the end of the full expression.
  REQUIRE_NOTHROW(libshm::shm_handle("test_shm"));

  std::error_code err;
  owner.unlink(err);
  REQUIRE(err.value() == 0);
  // After unlink the owner reports no descriptor.
  REQUIRE(owner.fd() == -1);

  // After unlink: the name no longer resolves, so attaching must fail.
  REQUIRE_THROWS(libshm::shm_handle("test_shm"));
}